import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt


# Load MNIST from the "MNIST_data/" directory with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Report how many examples each split (train / validation / test) holds.
for _caption, _split in (
    ("训练train数量:", mnist.train),
    ("验证validation数量:", mnist.validation),
    ("测试test数量:", mnist.test),
):
    print(_caption, _split.num_examples)

# Report the array shapes of the training images and their labels.
print("train images shape:", mnist.train.images.shape)
print("labels shape:", mnist.train.labels.shape)


def plot_image(image, shape=(28, 28), cmap="binary"):
    """Display a single flattened image with matplotlib.

    Args:
        image: Array-like object exposing ``reshape`` (e.g. a NumPy array)
            whose element count equals ``shape[0] * shape[1]``.
        shape: Target 2-D shape passed to ``image.reshape``. Defaults to
            ``(28, 28)``, the size this script uses for MNIST images.
        cmap: Matplotlib colormap name. Defaults to ``"binary"``.

    Note:
        ``plt.show()`` is called on every invocation. Depending on the
        matplotlib backend, this may block until the window is closed.
    """
    plt.imshow(image.reshape(shape), cmap=cmap)
    plt.show()


# Show the first 100 training images one after another. Iterating over a
# slice (instead of indexing with range(100)) avoids an IndexError if the
# training split holds fewer than 100 images, and does not leak an index
# variable into the module namespace.
for image in mnist.train.images[:100]:
    plot_image(image)
